# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generates actions transform a state to new states."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from neural_guided_symbolic_regression.mcts import states


class PolicyBase(object):
  """Abstract interface for policies that expand a state into next states.

  A concrete policy is expected to override:
    * get_new_states_probs
  """

  def get_new_states_probs(self, state):
    """Expands the input state into candidate next states with probabilities.

    This base implementation does no work; it always raises so that a
    subclass which forgets to override it fails loudly.

    Args:
      state: An object in mcts.states.StateBase holding all the information
          of the state to expand.

    Returns:
      new_states: A list of next states, one per action, each obtained by
          applying that action to the input state.
      action_probs: A float numpy array with shape [num_actions,] giving the
          probability of each action.

    Raises:
      NotImplementedError: Always, since subclasses must provide the logic.
    """
    message = 'Must be implemented by subclass.'
    raise NotImplementedError(message)


class ProductionRuleAppendPolicy(PolicyBase):
  """Policy that extends a production rule sequence by one valid rule.

  A new state is created by appending one production rule of the context-free
  grammar to the production rule sequence of the current state. In principle
  every state therefore has as many successors as the grammar has production
  rules. In practice only some rules are valid to append, so the remaining
  successors are forbidden.

  Following the encoding and decoding scheme of
  "Grammar Variational Autoencoder" (https://arxiv.org/abs/1703.01925),
  the production rule sequence is the preorder traversal of the parsing tree
  of the expression. For example, a parsing tree of expression 'a + T' can be

                          S
                          |
                       S '+' T
                       |
                       T
                       |
                      'a'

  whose preorder traversal is
  S -> S '+' T
  S -> T
  T -> 'a'

  Suppose the grammar is
  S -> S '+' T
  S -> S '-' T
  S -> S '*' T
  S -> S '/' T
  S -> T
  T -> 'a'
  T -> 'b'

  Of these 7 production rules, only T -> 'a' and T -> 'b' may be appended to
  the current state 'a + T': the preorder traversal requires the next rule to
  have T as its left hand side symbol. The prior probabilities of the first 5
  production rules are therefore nan.
  """

  def __init__(self, grammar):
    """Stores the grammar whose production rules define the actions.

    Args:
      grammar: nltk.grammar.CFG object for context-free grammar.
    """
    self._grammar = grammar

  def get_new_states_probs(self, state):
    """Builds one successor per grammar production rule, where allowed.

    Rules are visited in the order returned by the grammar's productions().
    For a rule the state accepts, the successor is a copy of the state with
    that rule appended. For a rule it rejects, the successor slot holds None
    and its probability is nan. Accepted rules share equal probability, and
    the non-nan entries are normalized by their sum.

    Args:
      state: A mcts.states.ProductionRulesState object. Contains a list of
          nltk.grammar.Production objects in attribute
          production_rules_sequence.

    Returns:
      new_states: A list with one entry per grammar production rule: the
          successor state, or None if the rule cannot be appended.
      action_probs: A float numpy array with shape [num_actions,]. Entry i is
          the probability of appending rule i, or nan if it is not allowed.

    Raises:
      TypeError: If input state is not states.ProductionRulesState object.
    """
    if not isinstance(state, states.ProductionRulesState):
      message = ('Input state shoud be an instance of '
                 'states.ProductionRulesState but got %s' % type(state))
      raise TypeError(message)

    successors = []
    weights = []
    for rule in self._grammar.productions():
      if not state.is_valid_to_append(rule):
        # Keep a placeholder so that indices stay aligned with the grammar's
        # production rules.
        successors.append(None)
        weights.append(np.nan)
        continue
      # Mutate a copy so the input state is left untouched.
      successor = state.copy()
      successor.append_production_rule(rule)
      successors.append(successor)
      weights.append(1.)

    weights = np.asarray(weights)
    # nansum skips the nan placeholders, so only allowed rules contribute to
    # the normalizer; the nan entries remain nan after the division.
    normalizer = np.nansum(weights)
    weights /= normalizer
    return successors, weights
